
import xml.dom.minidom as xmldom
from scipy.spatial.transform import Rotation
from collections import OrderedDict
import os


from gym import error, spaces
from gym.utils import seeding
import numpy as np
from os import path
import gym

try:
    import mujoco_py
except ImportError as e:
    raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e))

# Default width and height (in pixels) used by MujocoEnv.render() when the
# caller does not pass explicit dimensions.
DEFAULT_SIZE = 500


def convert_observation_to_space(observation):
    """Build a gym space mirroring the structure of ``observation``.

    * ``dict`` -> ``spaces.Dict`` whose children are converted recursively,
      keeping the dict's iteration order.
    * ``np.ndarray`` -> unbounded ``spaces.Box`` with the array's shape and
      dtype (bounds are +/- infinity stored as float32).

    Raises:
        NotImplementedError: for any other observation type.
    """
    if isinstance(observation, dict):
        children = OrderedDict()
        for name, sub_observation in observation.items():
            children[name] = convert_observation_to_space(sub_observation)
        return spaces.Dict(children)

    if isinstance(observation, np.ndarray):
        shape = observation.shape
        lower = np.full(shape, -float('inf'), dtype=np.float32)
        upper = np.full(shape, float('inf'), dtype=np.float32)
        return spaces.Box(lower, upper, dtype=observation.dtype)

    raise NotImplementedError(type(observation), observation)



def add_multi_dim_noise(std_list, rng=None):
    """Sample independent zero-mean Gaussian noise, one value per std entry.

    Args:
        std_list: Iterable of standard deviations; element ``i`` of the result
            is drawn with std ``std_list[i]``.
        rng: Optional random source exposing ``normal(loc, scale)`` (e.g.
            ``np.random.default_rng(seed)`` or an env's seeded
            ``np_random``). Defaults to the global ``np.random`` state, which
            is what this helper always used before, so existing callers are
            unaffected.

    Returns:
        1-D float64 ndarray with one entry per element of ``std_list``
        (shape ``(0,)`` for an empty input).

    Raises:
        ValueError: if a standard deviation is negative (raised by numpy).
    """
    rng = np.random if rng is None else rng
    scales = np.asarray(list(std_list), dtype=np.float64)
    # One vectorised draw (scale broadcasts per element) instead of a
    # Python-level loop of scalar draws.
    return np.asarray(rng.normal(0.0, scales), dtype=np.float64).reshape(-1)


def add_noise_for_qua(qua, eular_noise_params):
    """Perturb a quaternion by adding Gaussian noise to its Euler angles.

    ``qua`` is interpreted the way SciPy's ``Rotation.from_quat`` reads it,
    i.e. scalar-last ``(x, y, z, w)``; the result is returned in the same
    order. The rotation is expressed as SciPy 'xyz' Euler angles in degrees,
    noise with per-axis std ``eular_noise_params["roll"]``,
    ``["pitch"]`` and ``["yaw"]`` (degrees) is added, and the perturbed
    angles are converted back to a quaternion.
    """
    angles_deg = Rotation.from_quat(qua).as_euler('xyz', degrees=True)
    per_axis_std = [eular_noise_params[axis] for axis in ("roll", "pitch", "yaw")]
    noisy_angles_deg = angles_deg + add_multi_dim_noise(per_axis_std)
    noisy_rotation = Rotation.from_euler('xyz', noisy_angles_deg, degrees=True)
    return noisy_rotation.as_quat()


class MujocoEnv(gym.Env):
    """Superclass for all MuJoCo environments, with optional noise injection.

    Besides model loading, stepping and rendering, this class assembles
    observations from ``qpos``/``qvel`` (optionally plus ``cfrc_ext``) and can
    perturb both observations and actions with Gaussian noise.

    Constructor arguments beyond ``model_path`` / ``frame_skip``:

    * ``qpos_begin`` -- number of leading ``qpos`` entries left out of the
      observation.
    * ``qvel_clip`` -- if truthy, ``qvel`` is clipped to ``[-10, 10]`` in the
      observation.
    * ``add_cfrc_ext`` -- if truthy, ``sim.data.cfrc_ext`` clipped to
      ``[-1, 1]`` is appended to the observation.
    * ``noise_params`` -- ``None`` disables all noise. Otherwise a dict read as
      ``noise_params["action"]`` (std of control noise) and
      ``noise_params["pos"][...]`` / ``noise_params["vel"][...]``
      (per-joint-type observation noise stds, see ``parse_mujoco_joint``).
    """

    def __init__(self, model_path, frame_skip,
                 qpos_begin,
                 qvel_clip,
                 add_cfrc_ext,
                 noise_params=None):
        # Absolute paths are used as-is; relative ones are resolved against
        # the "assets" directory next to this module.
        if os.path.isabs(model_path):
            fullpath = model_path
        else:
            fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
        if not path.exists(fullpath):
            raise IOError("File %s does not exist" % fullpath)

        # Observation layout / noise configuration (see class docstring).
        self.qpos_begin = qpos_begin
        self.qvel_clip = qvel_clip
        self.add_cfrc_ext = add_cfrc_ext
        self.noise_params = noise_params

        self.frame_skip = frame_skip
        self.model = mujoco_py.load_model_from_path(fullpath)
        self.sim = mujoco_py.MjSim(self.model)
        self.data = self.sim.data
        self.viewer = None
        self._viewers = {}

        self.metadata = {
            'render.modes': ['human', 'rgb_array', 'depth_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }

        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()

        self._set_action_space()

        # Take one random step so the observation space can be inferred from
        # a real observation. Note: self.seed() has not run yet at this point.
        action = self.action_space.sample()
        observation, _reward, done, _info = self.step(action)
        assert not done

        self._set_observation_space(observation)

        self.seed()

    def _set_action_space(self):
        """Derive a float32 Box action space from the actuator control ranges."""
        bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
        low, high = bounds.T
        self.action_space = spaces.Box(low=low, high=high, dtype=np.float32)
        return self.action_space

    def _set_observation_space(self, observation):
        """Derive the observation space from a sample ``observation``."""
        self.observation_space = convert_observation_to_space(observation)
        return self.observation_space

    def seed(self, seed=None):
        """Seed ``self.np_random`` and return the seed actually used, in a list."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    # methods to override:
    # ----------------------------

    def reset_model(self):
        """
        Reset the robot degrees of freedom (qpos and qvel).
        Implement this in each subclass.
        """
        raise NotImplementedError

    def viewer_setup(self):
        """
        This method is called when the viewer is initialized.
        Optionally implement this method, if you need to tinker with camera position
        and so forth.
        """
        pass

    # -----------------------------

    def reset(self):
        """Reset the simulator, then let the subclass reset the model; return its observation."""
        self.sim.reset()
        ob = self.reset_model()
        return ob

    def set_state(self, qpos, qvel):
        """Overwrite ``qpos``/``qvel`` (keeping time, act and udd_state) and run forward()."""
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        old_state = self.sim.get_state()
        new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel,
                                         old_state.act, old_state.udd_state)
        self.sim.set_state(new_state)
        self.sim.forward()

    @property
    def dt(self):
        """Environment time step: simulator timestep times ``frame_skip``."""
        return self.model.opt.timestep * self.frame_skip

    def do_simulation(self, ctrl, n_frames):
        """Apply ``ctrl`` (plus optional action noise) and step ``n_frames`` times.

        When ``noise_params`` is set, zero-mean Gaussian noise with std
        ``noise_params["action"]`` is added to every control entry. The noisy
        control is a new array, so the caller's ``ctrl`` (typically the
        agent's action) is never modified in place.
        """
        if self.noise_params is not None:
            ctrl = np.asarray(ctrl, dtype=np.float64)
            ctrl = ctrl + add_multi_dim_noise([self.noise_params["action"]] * ctrl.size)
        self.sim.data.ctrl[:] = ctrl
        for _ in range(n_frames):
            self.sim.step()

    def render(self,
               mode='human',
               width=DEFAULT_SIZE,
               height=DEFAULT_SIZE,
               camera_id=None,
               camera_name=None):
        """Render the scene.

        ``'rgb_array'`` returns the colour image and ``'depth_array'`` the depth
        buffer, both flipped vertically; ``'human'`` draws to an on-screen
        viewer. For the array modes at most one of ``camera_id`` /
        ``camera_name`` may be given; with neither, the camera named
        ``'track'`` is used if the model has one.

        Raises:
            ValueError: if both ``camera_id`` and ``camera_name`` are given.
        """
        if mode == 'rgb_array' or mode == 'depth_array':
            if camera_id is not None and camera_name is not None:
                raise ValueError("Both `camera_id` and `camera_name` cannot be"
                                 " specified at the same time.")

            no_camera_specified = camera_name is None and camera_id is None
            if no_camera_specified:
                camera_name = 'track'

            if camera_id is None and camera_name in self.model._camera_name2id:
                camera_id = self.model.camera_name2id(camera_name)

            self._get_viewer(mode).render(width, height, camera_id=camera_id)

        if mode == 'rgb_array':
            # window size used for old mujoco-py:
            data = self._get_viewer(mode).read_pixels(width, height, depth=False)
            # original image is upside-down, so flip it
            return data[::-1, :, :]
        elif mode == 'depth_array':
            # The frame was already rendered above with the requested camera;
            # rendering it a second time here would only duplicate that work.
            # Extract depth part of the read_pixels() tuple
            data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1]
            # original image is upside-down, so flip it
            return data[::-1, :]
        elif mode == 'human':
            self._get_viewer(mode).render()

    def close(self):
        """Drop references to all viewers."""
        if self.viewer is not None:
            # self.viewer.finish()
            self.viewer = None
            self._viewers = {}

    def _get_viewer(self, mode):
        """Return the cached viewer for ``mode``, creating it on first use."""
        self.viewer = self._viewers.get(mode)
        if self.viewer is None:
            if mode == 'human':
                self.viewer = mujoco_py.MjViewer(self.sim)
            elif mode == 'rgb_array' or mode == 'depth_array':
                self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1)

            self.viewer_setup()
            self._viewers[mode] = self.viewer
        return self.viewer

    def get_body_com(self, body_name):
        """Return ``sim.data``'s position (``get_body_xpos``) of ``body_name``."""
        return self.data.get_body_xpos(body_name)

    def state_vector(self):
        """Return ``qpos`` followed by ``qvel`` as one flat array."""
        return np.concatenate([
            self.sim.data.qpos.flat,
            self.sim.data.qvel.flat
        ])

    # ------------------------------------------------------------------
    # Observation construction (with optional noise)
    # ------------------------------------------------------------------

    def _get_obs(self):
        """Return the noisy observation if ``noise_params`` is set, else the exact one."""
        if self.noise_params is None:
            return self._get_deterministic_obs()
        else:
            return self._get_noisy_obs()

    def _assemble_obs(self, qpos, qvel):
        """Build the observation vector from (possibly noisy) ``qpos``/``qvel``.

        Drops the first ``self.qpos_begin`` entries of ``qpos``, clips ``qvel``
        to ``[-10, 10]`` when ``self.qvel_clip`` is truthy, and appends
        ``cfrc_ext`` clipped to ``[-1, 1]`` when ``self.add_cfrc_ext`` is truthy.
        """
        qpos = np.asarray(qpos).ravel()
        qvel = np.asarray(qvel).ravel()
        obs_list = [qpos[self.qpos_begin:]]
        obs_list.append(np.clip(qvel, -10, 10) if self.qvel_clip else qvel)
        if self.add_cfrc_ext:
            obs_list.append(np.clip(self.sim.data.cfrc_ext, -1, 1).ravel())
        # concatenate copies, so the result never aliases sim.data memory.
        return np.concatenate(obs_list)

    def _get_deterministic_obs(self):
        """Observation built from the exact simulator ``qpos``/``qvel``."""
        return self._assemble_obs(self.sim.data.qpos, self.sim.data.qvel)

    @staticmethod
    def _noisy_mujoco_quat(quat_wxyz, euler_noise_params):
        """Apply ``add_noise_for_qua`` to a MuJoCo scalar-first quaternion.

        MuJoCo stores quaternions as ``(w, x, y, z)`` while SciPy's
        ``Rotation`` (used by ``add_noise_for_qua``) expects scalar-last
        ``(x, y, z, w)``, so the quaternion is reordered before and after.
        """
        quat_xyzw = np.roll(quat_wxyz, -1)
        noisy_xyzw = add_noise_for_qua(quat_xyzw, euler_noise_params)
        return np.roll(noisy_xyzw, 1)

    def parse_mujoco_joint(self):
        """Return noisy copies of ``qpos`` and ``qvel``, built joint by joint.

        Joint types (``model.jnt_type``) and their coordinate counts:

        ====  =====  ====  ====
        type  name   qpos  qvel
        ====  =====  ====  ====
        0     free   7     6
        1     ball   4     3
        2     slide  1     1
        3     hinge  1     1
        ====  =====  ====  ====

        Noise stds are read from ``self.noise_params``:

        * free  -- ``pos["xyz"]`` on the position, ``pos["roll"|"pitch"|"yaw"]``
          (degrees) on the orientation, ``vel["xyz"]`` on the first three and
          ``vel["rotate"]`` on the last three velocity entries.
        * ball  -- ``pos["roll"|"pitch"|"yaw"]`` (degrees) on the orientation,
          ``vel["rotate"]`` on the three velocity entries.
        * slide -- ``pos["xyz"]`` and ``vel["xyz"]``.
        * hinge -- ``pos["hinge"]`` given in degrees (converted to radians) and
          ``vel["hinge"]`` used as-is. NOTE(review): the velocity std is not
          unit-converted; presumably it is already in rad/s -- confirm.

        The simulator state is not modified.

        Returns:
            (qpos, qvel): two lists of floats, concatenated over joints in
            joint-id order.
        """
        noise_pos = self.noise_params["pos"]
        noise_vel = self.noise_params["vel"]
        qpos = []
        qvel = []
        for jnt_id, jnt_type in enumerate(self.model.jnt_type):
            jnt_name = self.model.joint_id2name(jnt_id)
            # np.array(...) copies: for multi-dof joints the getters may return
            # a view into sim.data, and the in-place noise below must not leak
            # into the physics state. ndmin=1 turns the scalar returned for
            # 1-dof joints into a length-1 array.
            jnt_qpos = np.array(self.sim.data.get_joint_qpos(jnt_name),
                                dtype=np.float64, ndmin=1)
            jnt_qvel = np.array(self.sim.data.get_joint_qvel(jnt_name),
                                dtype=np.float64, ndmin=1)
            if jnt_type == 0:  # free: xyz position + (w, x, y, z) quaternion
                jnt_qpos[:3] += add_multi_dim_noise([noise_pos["xyz"]] * 3)
                jnt_qpos[3:7] = self._noisy_mujoco_quat(jnt_qpos[3:7], noise_pos)
                jnt_qvel[:3] += add_multi_dim_noise([noise_vel["xyz"]] * 3)
                jnt_qvel[3:6] += add_multi_dim_noise([noise_vel["rotate"]] * 3)
            elif jnt_type == 1:  # ball: (w, x, y, z) quaternion
                jnt_qpos[:4] = self._noisy_mujoco_quat(jnt_qpos[:4], noise_pos)
                jnt_qvel[:3] += add_multi_dim_noise([noise_vel["rotate"]] * 3)
            elif jnt_type == 2:  # slide
                jnt_qpos += add_multi_dim_noise([noise_pos["xyz"]])
                jnt_qvel += add_multi_dim_noise([noise_vel["xyz"]])
            else:  # hinge
                jnt_qpos += add_multi_dim_noise([noise_pos["hinge"] * np.pi / 180])
                jnt_qvel += add_multi_dim_noise([noise_vel["hinge"]])
            qpos.extend(jnt_qpos)
            qvel.extend(jnt_qvel)

        return qpos, qvel

    def _get_noisy_obs(self):
        """Observation built from noisy copies of ``qpos``/``qvel``."""
        qpos, qvel = self.parse_mujoco_joint()
        return self._assemble_obs(qpos, qvel)
